# coding=utf-8

# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.

# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.

#

# Licensed under the Apache License, Version 2.0 (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

#     http://www.apache.org/licenses/LICENSE-2.0

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.

""" Utilities for PyTorch Transformer XL model.

    Directly adapted from https://github.com/kimiyoung/transformer-xl.

"""



from collections import defaultdict



import numpy as np



import torch

import torch.nn as nn

import torch.nn.functional as F



# CUDA_MAJOR = int(torch.version.cuda.split('.')[0])

# CUDA_MINOR = int(torch.version.cuda.split('.')[1])



class ProjectedAdaptiveLogSoftmax(nn.Module):

    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,

                 keep_order=False):

        super(ProjectedAdaptiveLogSoftmax, self).__init__()



        self.n_token = n_token

        self.d_embed = d_embed

        self.d_proj = d_proj



        self.cutoffs = cutoffs + [n_token]

        self.cutoff_ends = [0] + self.cutoffs

        self.div_val = div_val



        self.shortlist_size = self.cutoffs[0]

        self.n_clusters = len(self.cutoffs) - 1

        self.head_size = self.shortlist_size + self.n_clusters
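        # The "head" softmax covers the shortlist tokens plus one extra logit per
        # tail cluster; `cluster_weight` / `cluster_bias` below score those extra
        # cluster logits.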



        if self.n_clusters > 0:

            self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))

            self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))



        self.out_layers = nn.ModuleList()

        self.out_projs = nn.ParameterList()



        if div_val == 1:

            for i in range(len(self.cutoffs)):

                if d_proj != d_embed:

                    self.out_projs.append(

                        nn.Parameter(torch.FloatTensor(d_proj, d_embed))

                    )

                else:

                    self.out_projs.append(None)



            self.out_layers.append(nn.Linear(d_embed, n_token))

        else:

            for i in range(len(self.cutoffs)):

                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
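                # With div_val > 1, rarer clusters get progressively smaller
                # embedding dimensions (d_embed // div_val**i), each with its own
                # projection back from d_proj.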

                d_emb_i = d_embed // (div_val ** i)



                self.out_projs.append(

                    nn.Parameter(torch.FloatTensor(d_proj, d_emb_i))

                )



                self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx))



        self.keep_order = keep_order



    def _compute_logit(self, hidden, weight, bias, proj):

        if proj is None:

            logit = F.linear(hidden, weight, bias=bias)

        else:

            # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1:
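            # Project the hidden state down to this cluster's embedding size before
            # scoring it against the (possibly smaller) output embedding matrix.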

            proj_hid = F.linear(hidden, proj.t().contiguous())

            logit = F.linear(proj_hid, weight, bias=bias)

            # else:

            #     logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))

            #     if bias is not None:

            #         logit = logit + bias



        return logit



    def forward(self, hidden, labels=None, keep_order=False):

        '''
            Params:
                hidden :: [len*bsz x d_proj]
                labels :: [len*bsz]
            Return:
                if labels is None:
                    out :: [len*bsz x n_tokens] log probabilities of tokens over the vocabulary
                else:
                    out :: [len*bsz] negative log likelihood of the target labels
            We could replace this implementation by the native PyTorch one
            if theirs had an option to set the bias on all clusters:
            https://github.com/pytorch/pytorch/blob/dbe6a7a9ff1a364a8706bf5df58a1ca96d2fd9da/torch/nn/modules/adaptive.py#L138
        '''



        if labels is not None:

            labels = labels.view(-1)

            if hidden.size(0) != labels.size(0):

                raise RuntimeError('Input and labels should have the same size '

                                'in the batch dimension.')



        if self.n_clusters == 0:

            logit = self._compute_logit(hidden, self.out_layers[0].weight,

                                        self.out_layers[0].bias, self.out_projs[0])

            if labels is not None:

                out = -F.log_softmax(logit, dim=-1) \
                        .gather(1, labels.unsqueeze(1)).squeeze(1)

            else:

                out = F.log_softmax(logit, dim=-1)

        else:

            # construct weights and biases

            weights, biases = [], []

            for i in range(len(self.cutoffs)):

                if self.div_val == 1:

                    l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]

                    weight_i = self.out_layers[0].weight[l_idx:r_idx]

                    bias_i = self.out_layers[0].bias[l_idx:r_idx]

                else:

                    weight_i = self.out_layers[i].weight

                    bias_i = self.out_layers[i].bias



                if i == 0:

                    weight_i = torch.cat(

                        [weight_i, self.cluster_weight], dim=0)

                    bias_i = torch.cat(

                        [bias_i, self.cluster_bias], dim=0)



                weights.append(weight_i)

                biases.append(bias_i)



            head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]



            head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)

            head_logprob = F.log_softmax(head_logit, dim=1)



            if labels is None:

                out = hidden.new_empty((head_logit.size(0), self.n_token))

            else:

                out = torch.zeros_like(labels, dtype=hidden.dtype, device=hidden.device)



            offset = 0

            cutoff_values = [0] + self.cutoffs

            for i in range(len(cutoff_values) - 1):

                l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]



                if labels is not None:

                    mask_i = (labels >= l_idx) & (labels < r_idx)

                    indices_i = mask_i.nonzero().squeeze()



                    if indices_i.numel() == 0:

                        continue



                    target_i = labels.index_select(0, indices_i) - l_idx

                    head_logprob_i = head_logprob.index_select(0, indices_i)

                    hidden_i = hidden.index_select(0, indices_i)

                else:

                    hidden_i = hidden



                if i == 0:

                    if labels is not None:

                        logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)

                    else:

                        out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]]

                else:

                    weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]



                    tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)

                    tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
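                    # The full log-prob decomposes as log p(cluster | h) + log p(token | cluster, h):
                    # the head softmax contributes the cluster term, the tail softmax the within-cluster term.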

                    cluster_prob_idx = self.cutoffs[0] + i - 1  # No probability for the head cluster

                    if labels is not None:

                        logprob_i = head_logprob_i[:, cluster_prob_idx] \
                                + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1)

                    else:

                        logprob_i = head_logprob[:, cluster_prob_idx, None] + tail_logprob_i

                        out[:, l_idx:r_idx] = logprob_i



                if labels is not None:

                    if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:

                        out.index_copy_(0, indices_i, -logprob_i)

                    else:

                        out[offset:offset+logprob_i.size(0)].copy_(-logprob_i)

                    offset += logprob_i.size(0)



        return out





    def log_prob(self, hidden):

        r""" Computes log probabilities for all :math:`n\_classes`

        From: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/adaptive.py

        Args:

            hidden (Tensor): a minibatch of examples

        Returns:

            log-probabilities for each class :math:`c`
            in range :math:`0 <= c < n\_classes`, where :math:`n\_classes` is the
            ``n_token`` value passed to the constructor.

        Shape:

            - Input: :math:`(N, in\_features)`

            - Output: :math:`(N, n\_classes)`

        """

        if self.n_clusters == 0:

            logit = self._compute_logit(hidden, self.out_layers[0].weight,

                                        self.out_layers[0].bias, self.out_projs[0])

            return F.log_softmax(logit, dim=-1)

        else:

            # construct weights and biases

            weights, biases = [], []

            for i in range(len(self.cutoffs)):

                if self.div_val == 1:

                    l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]

                    weight_i = self.out_layers[0].weight[l_idx:r_idx]

                    bias_i = self.out_layers[0].bias[l_idx:r_idx]

                else:

                    weight_i = self.out_layers[i].weight

                    bias_i = self.out_layers[i].bias



                if i == 0:

                    weight_i = torch.cat(

                        [weight_i, self.cluster_weight], dim=0)

                    bias_i = torch.cat(

                        [bias_i, self.cluster_bias], dim=0)



                weights.append(weight_i)

                biases.append(bias_i)



            head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]

            head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)



            out = hidden.new_empty((head_logit.size(0), self.n_token))

            head_logprob = F.log_softmax(head_logit, dim=1)



            cutoff_values = [0] + self.cutoffs

            for i in range(len(cutoff_values) - 1):

                start_idx, stop_idx = cutoff_values[i], cutoff_values[i + 1]



                if i == 0:

                    out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]]

                else:

                    weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]



                    tail_logit_i = self._compute_logit(hidden, weight_i, bias_i, proj_i)

                    tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)



                    # Cluster i's log-probability sits at column `self.cutoffs[0] + i - 1`
                    # of the head, mirroring `cluster_prob_idx` in `forward`.
                    logprob_i = head_logprob[:, self.cutoffs[0] + i - 1, None] + tail_logprob_i
                    out[:, start_idx:stop_idx] = logprob_i



            return out
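

# A minimal usage sketch of the adaptive softmax head above. The sizes and
# cutoffs are illustrative assumptions, chosen only to make the shapes visible;
# they are not values required by the class.
def _example_projected_softmax_usage():
    n_token, d_embed, d_proj = 1000, 32, 32
    crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_proj,
                                       cutoffs=[100, 500], div_val=1)
    hidden = torch.randn(8, d_proj)             # [len*bsz x d_proj]
    labels = torch.randint(0, n_token, (8,))    # [len*bsz]
    nll = crit(hidden, labels)                  # [len*bsz] negative log-likelihood
    logprobs = crit(hidden)                     # [len*bsz x n_token] log-probabilities
    # Sanity check: each row of `logprobs` should sum (in probability space) to ~1.
    return nll.mean(), logprobs.logsumexp(dim=-1)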





class LogUniformSampler(object):

    def __init__(self, range_max, n_sample):

        """

        Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py

            `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`



        expected count can be approximated by 1 - (1 - p)^n

        and we use a numerically stable version -expm1(num_tries * log1p(-p))



        Our implementation fixes num_tries at 2 * n_sample; duplicates are removed,
        so the actual number of samples varies from run to run.

        """

        with torch.no_grad():

            self.range_max = range_max

            log_indices = torch.arange(1., range_max+2., 1.).log_()

            self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]



            self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float()
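            # log of the expected count of each class under sampling. Note the code
            # uses -log1p(p) where the docstring's formula has log1p(-p); for the
            # small per-class probabilities involved the two agree to first order.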



        self.n_sample = n_sample



    def sample(self, labels):

        """

            labels: [b1, b2]

        Return

            true_log_probs: [b1, b2]

            samp_log_probs: [n_sample]

            neg_samples: [n_sample]

        """



        # neg_samples = torch.empty(0).long()

        n_sample = self.n_sample

        n_tries = 2 * n_sample



        with torch.no_grad():
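            # Draw 2 * n_sample candidates with replacement from the log-uniform
            # distribution and deduplicate, so the number of negatives returned
            # varies from call to call.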

            neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()

            device = labels.device

            neg_samples = neg_samples.to(device)

            true_log_probs = self.log_q[labels].to(device)

            samp_log_probs = self.log_q[neg_samples].to(device)

            return true_log_probs, samp_log_probs, neg_samples
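

# A short illustrative sketch of the sampler: its distribution
# P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)
# sums to 1 over [0, range_max) and heavily favours low (frequent) token ids.
# The sizes below are illustrative assumptions.
def _example_log_uniform_sampler():
    sampler = LogUniformSampler(range_max=1000, n_sample=16)
    labels = torch.randint(0, 1000, (4, 8))     # [b1, b2] dummy targets
    true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels)
    # sampler.dist sums to ~1.0; neg_samples holds at most 2 * n_sample unique ids
    return sampler.dist.sum(), true_log_probs.shape, neg_samples.numel()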



def sample_logits(embedding, bias, labels, inputs, sampler):

    """

        embedding: an nn.Embedding layer

        bias: [n_vocab]

        labels: [b1, b2]

        inputs: [b1, b2, n_emb]

        sampler: you may use a LogUniformSampler

    Return

        logits: [b1, b2, 1 + n_sample]

    """

    true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels)

    n_sample = neg_samples.size(0)

    b1, b2 = labels.size(0), labels.size(1)

    all_ids = torch.cat([labels.view(-1), neg_samples])

    all_w = embedding(all_ids)

    true_w = all_w[: -n_sample].view(b1, b2, -1)

    sample_w = all_w[- n_sample:].view(n_sample, -1)



    all_b = bias[all_ids]

    true_b = all_b[: -n_sample].view(b1, b2)

    sample_b = all_b[- n_sample:]



    hit = (labels[:, :, None] == neg_samples).detach()
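    # `hit` marks negatives that collide with the true label at each position; their
    # logits are pushed to -1e30 below so they do not act as false negatives.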



    true_logits = torch.einsum('ijk,ijk->ij',

        [true_w, inputs]) + true_b - true_log_probs

    sample_logits = torch.einsum('lk,ijk->ijl',

        [sample_w, inputs]) + sample_b - samp_log_probs

    sample_logits.masked_fill_(hit, -1e30)

    logits = torch.cat([true_logits[:, :, None], sample_logits], -1)



    return logits
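

# A hedged end-to-end sketch of sampled-softmax training with `LogUniformSampler`
# and `sample_logits`: slot 0 of the last logits dimension is the true token, so
# the cross-entropy target is simply all zeros. The embedding, bias and sizes
# below are illustrative assumptions.
def _example_sampled_softmax_loss():
    n_vocab, d_emb = 1000, 32
    embedding = nn.Embedding(n_vocab, d_emb)
    bias = torch.zeros(n_vocab)
    sampler = LogUniformSampler(range_max=n_vocab, n_sample=16)
    labels = torch.randint(0, n_vocab, (4, 8))    # [b1, b2]
    inputs = torch.randn(4, 8, d_emb)             # [b1, b2, n_emb] hidden states
    logits = sample_logits(embedding, bias, labels, inputs, sampler)  # [b1, b2, 1 + n_sample]
    targets = torch.zeros(logits.size(0) * logits.size(1), dtype=torch.long)
    return F.cross_entropy(logits.view(-1, logits.size(-1)), targets)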

